/**
 * 
 */
package edu.umd.clip.lm.ngram;

import java.io.IOException;
import java.util.*;

import edu.umd.clip.lm.factors.*;
import edu.umd.clip.lm.factors.Dictionary;
import edu.umd.clip.lm.model.Experiment;
import edu.umd.clip.lm.model.data.*;

/**
 * Draws samples from the conditional distributions of a set of
 * {@link NgramModel}s, one model at a time (see {@link #drawOneSample}).
 * <p>
 * This class is work in progress: {@link #doSampling} is an empty stub and
 * {@link #drawOneSample} carries an unresolved TODO about packed factors.
 *
 * @author Denis Filimonov <den@cs.umd.edu>
 *
 */
public class TrainingDataGibbsSampler {
	private final NgramModel[] models;
	// Order of the language model of the current Experiment; not read anywhere in this class yet.
	private final int order;
	// One generator shared by all draws instead of allocating a new one per call.
	// java.util.Random instances are documented as thread-safe.
	private final Random rnd = new Random();
	
	/**
	 * @param models the models to sample from; the array is stored as is (not copied)
	 */
	public TrainingDataGibbsSampler(NgramModel[] models) {
		this.models = models;

		Experiment exp = Experiment.getInstance();
		this.order = exp.getLM().getOrder();
	}

	/**
	 * Creates the initial sample: an all-zero array with one slot per model.
	 * Currently not called from within this class.
	 */
	private int[] makeInitialSample() {
		return new int[models.length];
	}
	
	/**
	 * Not implemented yet: the body is empty and both arguments are ignored.
	 */
	public void doSampling(long sampleSize, WritableTrainingData data) throws IOException {
	}
	
	/**
	 * Draws one value from the distribution that {@code models[currentModel]}
	 * assigns to its future vocabulary in the given context.
	 * <p>
	 * The context values are obtained from the model's context variables applied to
	 * {@code currentContext.data} and {@code futureTuple.getBits()}. A value is then
	 * drawn with probability proportional to {@code model.getProb(value, ctx)}.
	 * <p>
	 * If the drawn value is a start symbol, it is accepted only when every position
	 * from {@code -currentContext.getOrder()+1} up to (excluding) the future
	 * variable's offset, counted from the end of {@code currentContext.data}, also
	 * holds a start symbol in its main factor; otherwise a new value is drawn.
	 * On acceptance a start tuple is written into {@code currentContext.data}
	 * at the future variable's offset.
	 * <p>
	 * NOTE(review): a drawn value that is not a start symbol is not stored anywhere
	 * by this method; only the return value reflects that a draw took place.
	 * <p>
	 * NOTE(review): if only start symbols can be drawn while the context contains a
	 * non-start symbol, the resampling loop never terminates.
	 *
	 * @param currentModel index into the model array
	 * @param currentContext context whose {@code data} is read and, for an accepted
	 *        start symbol, modified in place
	 * @param futureTuple tuple whose bits are passed to the context variables
	 * @return {@code models.length - 1} if a start symbol was accepted, otherwise
	 *         {@code currentModel - 1} (presumably the index of the model to sample
	 *         next - confirm against the caller)
	 * @throws IllegalStateException if the model offers no future values for the context
	 */
	public int drawOneSample(int currentModel, Context currentContext, FactorTuple futureTuple) {
		final NgramModel model = models[currentModel];
		final CtxVar[] contextVariables = model.getContextVariables();
		final long[] data = currentContext.data;
		
		final FactorTupleDescription desc = Experiment.getInstance().getTupleDescription();
		final byte mainFactorIndex = desc.getMainFactorIndex();
		
		final int[] ctx = new int[contextVariables.length];
		for(int i=0; i<ctx.length; ++i) {
			ctx[i] = contextVariables[i].getValue(data, futureTuple.getBits());
		}
		
		final int[] vocab = model.getFutureVocab(ctx);
		if (vocab.length == 0) {
			// Without this guard the index clamp below would produce index -1.
			throw new IllegalStateException("model " + currentModel
					+ " has an empty future vocabulary for context " + Arrays.toString(ctx));
		}

		// Running sums of the (not necessarily normalized) probabilities;
		// the last entry equals the total mass.
		final double[] cumulativeDistribution = new double[vocab.length];
		double sum = 0;
		for(int i=0; i<vocab.length; ++i) {
			sum += model.getProb(vocab[i], ctx);
			cumulativeDistribution[i] = sum;
		}

		dosample: while(true) {
			final int sample = vocab[pickIndex(cumulativeDistribution, rnd.nextDouble() * sum)];
			
			// check if the tuple is valid
			// the sample is invalid if there is a word before <s> in the tuple

			// TODO: this is wrong!
			// vocab contains *packed* factors!!
			if (!Dictionary.isStart(sample)) {
				return currentModel - 1;
			}
			
			// Read once per start sample; assumes the getters have no side effects.
			final int futureOffset = model.getFutureVariable().getOffset();
			
			// take another sample if there is a non-start before current
			for(int i=-currentContext.getOrder()+1; i<futureOffset; ++i) {
				if (!Dictionary.isStart(FactorTuple.getValue(data[data.length+i], mainFactorIndex))) {
					continue dosample;
				}
			}
			data[data.length+futureOffset] = desc.createStartTuple();
			// re-start sampling from the last model
			// because all previous (in time) models must generate <s> as well
			return models.length - 1;
		}
	}

	/**
	 * Maps a threshold onto the index of the bucket it falls into.
	 *
	 * @param cumulative non-decreasing running sums, must not be empty
	 * @param threshold value to locate
	 * @return the first index whose cumulative value is strictly greater than
	 *         {@code threshold}, or the last index if there is none
	 */
	private static int pickIndex(double[] cumulative, double threshold) {
		int idx = Arrays.binarySearch(cumulative, threshold);
		if (idx < 0) {
			// not found: decode the insertion point, i.e. the first entry greater than the threshold
			idx = -idx - 1;
		} else {
			// Exact hit: bucket i covers [cumulative[i-1], cumulative[i]), so the threshold
			// belongs to a later bucket. Skipping every entry <= threshold also steps over
			// zero-probability entries, which share their predecessor's cumulative value.
			while (idx < cumulative.length && cumulative[idx] <= threshold) {
				++idx;
			}
		}
		// Rounding (or a zero total) can leave the threshold at or above the last entry.
		return Math.min(idx, cumulative.length - 1);
	}
}
